import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms

from detr.main import build_ACT_model_and_optimizer, build_HATACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
import IPython
# Short alias for IPython.embed (interactive debugging shell): call `e()` to
# inspect local state. Not referenced anywhere else in this module's visible code.
e = IPython.embed

class HATACTPolicy(nn.Module):
    """Policy wrapper around the model built by ``build_HATACT_model_and_optimizer``.

    Calling the policy with ``actions`` computes a training loss dict; calling
    it without ``actions`` returns the first ``short_horizon`` predicted actions
    together with the model's two extra outputs (``cls_1``, ``cls_2``).

    NOTE: unlike ACTPolicy no KL regularisation is applied, because this
    model's forward pass returns ``(a_hat, is_pad_hat, cls_1, cls_2)`` rather
    than ``(mu, logvar)``.
    """

    # Number of leading timesteps that get their own l1 term during training
    # and that are returned at inference time.
    short_horizon = 50
    # Weight of the short-horizon l1 term in the total 'l1' loss.
    short_horizon_loss_weight = 0.2
    # Values accepted for the ``loss_type`` argument of ``__call__``.
    _LOSS_TYPES = ('l1', 'l2', 'huber')

    def __init__(self, args_override):
        super().__init__()
        model, optimizer = build_HATACT_model_and_optimizer(args_override)
        self.model = model # CVAE encoder + decoder
        self.optimizer = optimizer
        # NOTE(review): kept for parity with ACTPolicy; no KL term uses it here.
        self.kl_weight = args_override['kl_weight']
        print(f'KL Weight {self.kl_weight}')

    def __call__(self, qpos, image, actions=None, is_pad=None, loss_type='l1'):
        """Compute training losses, or predict actions when ``actions`` is None.

        Args:
            qpos: forwarded unchanged to the model.
            image: image tensor; ImageNet channel mean/std normalisation is
                applied here before it reaches the model.
            actions: target action sequence. When given, the call is a
                training step and the sequence is truncated to
                ``model.num_queries`` along dim 1.
            is_pad: mask aligned with ``actions``. Entries that are True are
                zeroed out of the loss. Required whenever ``actions`` is given.
            loss_type: ``'l1'``, ``'l2'`` or ``'huber'`` (training only).

        Returns:
            Training: dict with the scalar to optimise under ``'loss'``. For
            ``'l1'`` it also holds ``'l1'`` (short-horizon l1) and ``'l1_100'``
            (full-sequence l1).
            Inference: ``(a_hat[:, :short_horizon, :], cls_1, cls_2)``.

        Raises:
            ValueError: if ``actions`` is given together with an unknown
                ``loss_type`` or without ``is_pad``.
        """
        env_state = None
        training = actions is not None
        if training:
            # Fail loudly up front instead of returning a dict without 'loss'.
            if loss_type not in self._LOSS_TYPES:
                raise ValueError(
                    f"unknown loss_type {loss_type!r}; expected one of {self._LOSS_TYPES}")
            if is_pad is None:
                raise ValueError("is_pad is required when actions are given")

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        image = normalize(image)

        if not training:
            a_hat, _, cls_1, cls_2 = self.model(qpos, image, env_state) # no action, sample from prior
            return a_hat[:, :self.short_horizon, :], cls_1, cls_2

        actions = actions[:, :self.model.num_queries]
        is_pad = is_pad[:, :self.model.num_queries]
        a_hat, _is_pad_hat, _cls_1, _cls_2 = self.model(qpos, image, env_state, actions, is_pad)

        # Zero out padded timesteps. Note that .mean() below still divides by
        # the full element count, padded steps included.
        valid = ~is_pad.unsqueeze(-1)
        horizon = self.short_horizon
        loss_dict = dict()
        if loss_type == 'l1':
            all_l1 = F.l1_loss(actions, a_hat, reduction='none')
            l1 = (all_l1 * valid).mean()
            l1_first = (all_l1[:, :horizon, :] * valid[:, :horizon]).mean()
            loss_dict['l1'] = l1_first
            loss_dict['l1_100'] = l1
            loss_dict['loss'] = self.short_horizon_loss_weight * loss_dict['l1'] + loss_dict['l1_100']
        elif loss_type == 'l2':
            all_l2 = F.mse_loss(actions, a_hat, reduction='none')
            loss_dict['loss'] = (all_l2 * valid).mean()
        else:  # 'huber' (validated above)
            all_huber = F.huber_loss(actions, a_hat, reduction='none', delta=0.4)
            loss_dict['loss'] = (all_huber * valid).mean()
        return loss_dict

    def configure_optimizers(self):
        """Return the optimizer created together with the model."""
        return self.optimizer
    
class ACTPolicy(nn.Module):
    """Policy wrapper around the CVAE built by ``build_ACT_model_and_optimizer``.

    With ``actions`` the call returns a loss dict ``{'l1', 'kl', 'loss'}``
    where ``loss = l1 + kl * kl_weight``; without ``actions`` it returns the
    model's predicted actions.
    """

    def __init__(self, args_override):
        super().__init__()
        # CVAE encoder + decoder, and the optimizer built for it.
        self.model, self.optimizer = build_ACT_model_and_optimizer(args_override)
        self.kl_weight = args_override['kl_weight']
        print(f'KL Weight {self.kl_weight}')

    def __call__(self, qpos, image, actions=None, is_pad=None):
        """Compute training losses, or predict actions when ``actions`` is None.

        ``image`` is normalised with the ImageNet channel mean/std first.
        During training, ``actions`` and ``is_pad`` are truncated to
        ``model.num_queries`` along dim 1. Steps where ``is_pad`` is True are
        zeroed out of the l1 term.
        """
        image = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])(image)
        env_state = None

        if actions is None:
            # Inference: no target actions, so sample from the prior.
            a_hat, _, (_, _) = self.model(qpos, image, env_state)
            return a_hat

        horizon = self.model.num_queries
        actions, is_pad = actions[:, :horizon], is_pad[:, :horizon]

        a_hat, _, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
        total_kld, _, _ = kl_divergence(mu, logvar)
        masked_l1 = F.l1_loss(actions, a_hat, reduction='none') * ~is_pad.unsqueeze(-1)
        l1 = masked_l1.mean()
        kl = total_kld[0]
        return {'l1': l1, 'kl': kl, 'loss': l1 + kl * self.kl_weight}

    def configure_optimizers(self):
        """Return the optimizer created together with the model."""
        return self.optimizer
    
    




class CNNMLPPolicy(nn.Module):
    """Single-step policy wrapper around ``build_CNNMLP_model_and_optimizer``.

    Training uses only the first timestep of ``actions`` as the target and
    scores it with MSE. ``is_pad`` is accepted for interface parity but unused.
    """

    def __init__(self, args_override):
        super().__init__()
        # The model (decoder) and the optimizer built for it.
        self.model, self.optimizer = build_CNNMLP_model_and_optimizer(args_override)

    def __call__(self, qpos, image, actions=None, is_pad=None):
        """Return ``{'mse', 'loss'}`` when ``actions`` is given, else the model's prediction.

        ``image`` is normalised with the ImageNet channel mean/std first.
        """
        image = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])(image)
        env_state = None  # TODO

        if actions is None:
            # Inference: no target actions available.
            return self.model(qpos, image, env_state)

        target = actions[:, 0]
        prediction = self.model(qpos, image, env_state, target)
        mse = F.mse_loss(target, prediction)
        return {'mse': mse, 'loss': mse}

    def configure_optimizers(self):
        """Return the optimizer created together with the model."""
        return self.optimizer

def kl_divergence(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from a standard normal.

    Elementwise: ``-0.5 * (1 + logvar - mu**2 - exp(logvar))``.

    Args:
        mu: mean tensor; dim 0 is averaged over (batch), dim 1 is summed over.
            4-D inputs are viewed as ``(size(0), size(1))``, which only
            succeeds when the two trailing dims are singletons.
        logvar: log-variance tensor with the same layout as ``mu``.

    Returns:
        ``(total_kld, dimension_wise_kld, mean_kld)``. For 2-D inputs these are:
        the dim-1 sum averaged over dim 0 (shape ``(1,)``); the dim-0 average
        per column; and the dim-1 mean averaged over dim 0 (shape ``(1,)``).

    Raises:
        ValueError: if ``mu`` has an empty dim 0.
    """
    batch_size = mu.size(0)
    if batch_size == 0:
        # Explicit raise rather than ``assert``: asserts vanish under
        # ``python -O`` and the means below would silently become NaN.
        raise ValueError('kl_divergence requires a non-empty batch (mu.size(0) == 0)')
    if mu.dim() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.dim() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)

    return total_kld, dimension_wise_kld, mean_kld
